package hotnet.filter;

import hotnet.session.DummySession;
import hotnet.session.IoSession;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A {@link ThreadPoolExecutor} that runs {@link IoEvent}s belonging to the same
 * {@link IoSession} one at a time and in submission order, while events of different
 * sessions may run concurrently on different worker threads.
 * <p>
 * Each session owns a task queue stored as a session attribute. When a session gets
 * work and no worker is currently draining it, the session is placed on a shared
 * waiting queue; a worker takes it from there and runs that session's tasks until its
 * queue is empty. Worker threads are managed by this class itself: the superclass is
 * only consulted for configuration (pool sizes, keep-alive, thread factory and
 * rejected-execution handler), because {@link #execute(Runnable)} is overridden and
 * never delegates to it.
 */
public class OrderedThreadPoolExecutor extends ThreadPoolExecutor {
	/** Session attribute key under which each session's task queue is stored. */
	private static final String TASKS_QUEUE = OrderedThreadPoolExecutor.class.getName() + ":TASKQUEUE";
	/** Sentinel put on {@link #waitingSessions}; the worker that takes it exits. */
	private static final IoSession EXIT_SIGNAL = new DummySession();

	private static final int DEFAULT_CORE_POOL_SIZE = 0;
	private static final int DEFAULT_MAX_POOL_SIZE = 16;
	/** Default keep-alive, in seconds, for workers above the core pool size. */
	private static final int DEFAULT_KEEP_ALIVE = 30;

	/** Sessions with pending tasks that are waiting for a worker (plus exit signals). */
	private final BlockingQueue<IoSession> waitingSessions = new LinkedBlockingQueue<IoSession>();
	/** Live workers; also the monitor guarding pool-size changes, shutdown and termination. */
	private final Set<Worker> workers = new HashSet<Worker>();
	/** Number of workers currently waiting for a session rather than running tasks. */
	private final AtomicInteger idleWorkers = new AtomicInteger();
	private volatile int largestPoolSize;

	private volatile boolean shutdown;

	/**
	 * Creates an executor with a core pool size of 0, a maximum pool size of 16 and a
	 * 30 second keep-alive.
	 */
	public OrderedThreadPoolExecutor() {
		this(DEFAULT_CORE_POOL_SIZE, DEFAULT_MAX_POOL_SIZE, DEFAULT_KEEP_ALIVE,
				TimeUnit.SECONDS);
	}

	/**
	 * Creates an executor with the given pool configuration.
	 *
	 * @param corePoolSize  number of workers kept alive even when idle
	 * @param maximumPoolSize upper bound on the number of workers
	 * @param keepAliveTime how long a worker above the core size waits for work before exiting
	 * @param unit          unit of {@code keepAliveTime}
	 */
	public OrderedThreadPoolExecutor(int corePoolSize, int maximumPoolSize,
			long keepAliveTime, TimeUnit unit) {
		// The superclass work queue is never used: execute() is overridden and manages
		// its own per-session queues and worker threads.
		super(corePoolSize, maximumPoolSize, keepAliveTime, unit,
				new SynchronousQueue<Runnable>(),
				new AbortPolicy());
	}

	/**
	 * Starts one new worker unless the executor is shut down or the pool is already at
	 * its maximum size.
	 */
	private void addWorker() {
		synchronized (workers) {
			// A worker started after shutdown would never receive an exit signal.
			if (shutdown || workers.size() >= super.getMaximumPoolSize()) {
				return;
			}

			Worker worker = new Worker();
			Thread thread = getThreadFactory().newThread(worker);

			// Count the worker as idle before it runs, so concurrent execute() calls see
			// it in addWorkerIfNecessary() and do not start another one.
			idleWorkers.incrementAndGet();
			workers.add(worker);
			thread.start();

			if (workers.size() > largestPoolSize) {
				largestPoolSize = workers.size();
			}
		}
	}

	/**
	 * Asks one arbitrary worker to exit by enqueueing {@link #EXIT_SIGNAL}, unless the
	 * pool is already at or below its core size.
	 */
	private void removeWorker() {
		synchronized (workers) {
			if (workers.size() <= getCorePoolSize()) {
				return;
			}

			// Whichever worker polls the signal first exits.
			waitingSessions.offer(EXIT_SIGNAL);
		}
	}

	/** Starts a worker when none is idle (or none exists at all). */
	private void addWorkerIfNecessary() {
		// Cheap unsynchronized pre-check; re-checked under the lock below.
		if (idleWorkers.get() == 0) {
			synchronized (workers) {
				if (idleWorkers.get() == 0 || workers.isEmpty()) {
					addWorker();
				}
			}
		}
	}

	@Override
	public int getMaximumPoolSize() {
		return super.getMaximumPoolSize();
	}

	/**
	 * Changes the maximum number of workers. If more workers are running than the new
	 * maximum allows, the surplus is asked to exit.
	 *
	 * @throws IllegalArgumentException if {@code maximumPoolSize} is not positive or is
	 *         smaller than the core pool size
	 */
	@Override
	public void setMaximumPoolSize(int maximumPoolSize) {
		if (maximumPoolSize <= 0 || getCorePoolSize() > maximumPoolSize) {
			throw new IllegalArgumentException("maximumPoolSize: " + maximumPoolSize );
		}

		synchronized (workers) {
			// addWorker() reads the limit from the superclass, so it must be updated there.
			super.setMaximumPoolSize(maximumPoolSize);

			for (int surplus = workers.size() - maximumPoolSize; surplus > 0; surplus--) {
				removeWorker();
			}
		}
	}

	/** Returns the number of live workers managed by this executor. */
	@Override
	public int getPoolSize() {
		synchronized (workers) {
			return workers.size();
		}
	}

	/** Returns the largest number of workers that have been alive simultaneously. */
	@Override
	public int getLargestPoolSize() {
		return largestPoolSize;
	}

	/**
	 * Shuts the executor down and removes every task still waiting in a session queue
	 * that is on the waiting list. Tasks of a session a worker is currently draining are
	 * not removed.
	 *
	 * @return the tasks that were removed, in no particular cross-session order
	 */
	@Override
	public List<Runnable> shutdownNow() {
		shutdown();

		List<Runnable> answer = new ArrayList<Runnable>();
		int exitSignals = 0;
		IoSession session;

		while ((session = waitingSessions.poll()) != null) {
			if (session == EXIT_SIGNAL) {
				// Set exit signals aside instead of re-offering them inside this loop:
				// if no worker is free to consume one, re-offering would spin forever.
				exitSignals++;
				continue;
			}

			SessionTasksQueue sessionTasksQueue = (SessionTasksQueue) session.getAttribute(TASKS_QUEUE, null);
			if (sessionTasksQueue == null) {
				continue;
			}

			synchronized (sessionTasksQueue.taskQueue) {
				answer.addAll(sessionTasksQueue.taskQueue);
				sessionTasksQueue.taskQueue.clear();
			}
		}

		// Hand the exit signals back so the workers still terminate.
		for (int i = 0; i < exitSignals; i++) {
			waitingSessions.offer(EXIT_SIGNAL);
		}

		return answer;
	}

	/** Returns {@code true} once the executor is shut down and every worker has exited. */
	@Override
	public boolean isTerminated() {
		if (!shutdown) {
			return false;
		}

		synchronized (workers) {
			return workers.isEmpty();
		}
	}

	/**
	 * Blocks until all workers have exited after a shutdown, the timeout elapses, or the
	 * calling thread is interrupted.
	 *
	 * @return {@code true} if the executor terminated, {@code false} on timeout
	 * @throws InterruptedException if interrupted while waiting
	 */
	@Override
	public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
		long remaining = unit.toNanos(timeout);
		// nanoTime differences stay correct even if the deadline sum overflows.
		long deadline = System.nanoTime() + remaining;

		synchronized (workers) {
			// Workers call workers.notifyAll() when they exit.
			while (!isTerminated()) {
				if (remaining <= 0) {
					return false;
				}
				TimeUnit.NANOSECONDS.timedWait(workers, remaining);
				remaining = deadline - System.nanoTime();
			}
			return true;
		}
	}

	@Override
	public boolean isShutdown() {
		return shutdown;
	}

	/**
	 * Stops accepting new tasks and asks every current worker to exit once it reaches
	 * the exit signal in the waiting queue. Calling it again has no effect.
	 */
	@Override
	public void shutdown() {
		synchronized (workers) {
			// Check-and-set under the lock so concurrent callers cannot both enqueue a
			// full set of exit signals.
			if (shutdown) {
				return;
			}
			shutdown = true;

			for (int i = workers.size(); i > 0; i--) {
				waitingSessions.offer(EXIT_SIGNAL);
			}
		}
	}

	/**
	 * Queues {@code task} behind any earlier tasks of the same session. If that session
	 * was not already waiting for or being processed by a worker, it is put on the
	 * waiting queue; a worker is started when none is idle.
	 * <p>
	 * After shutdown the task is handed to the rejected-execution handler
	 * ({@link AbortPolicy} by default) and never queued.
	 *
	 * @param task an {@link IoEvent}
	 * @throws IllegalArgumentException if {@code task} is not an {@link IoEvent}
	 */
	@Override
	public void execute(Runnable task) {
		if (shutdown) {
			getRejectedExecutionHandler().rejectedExecution(task, this);
			// A non-throwing handler (e.g. DiscardPolicy) must not let the task through.
			return;
		}

		if (!(task instanceof IoEvent)) {
			throw new IllegalArgumentException("task must be an IoEvent or its subclass.");
		}
		IoEvent event = (IoEvent) task;

		IoSession session = event.getSession();
		SessionTasksQueue sessionTasksQueue = getSessionTasksQueue(session);
		Queue<Runnable> tasksQueue = sessionTasksQueue.taskQueue;

		synchronized (tasksQueue) {
			tasksQueue.add(event);

			// Only enqueue the session when no worker owns it; otherwise the owning
			// worker will pick this task up in order.
			if (sessionTasksQueue.processingCompleted) {
				sessionTasksQueue.processingCompleted = false;
				waitingSessions.offer(session);
			}

			addWorkerIfNecessary();
		}
	}

	/**
	 * Returns the task queue attached to {@code session}, creating and attaching one if
	 * absent. If another thread attaches a queue first, that queue is returned instead.
	 */
	private SessionTasksQueue getSessionTasksQueue(IoSession session) {
		SessionTasksQueue queue = (SessionTasksQueue) session.getAttribute(TASKS_QUEUE, null);

		if (queue == null) {
			queue = new SessionTasksQueue();
			SessionTasksQueue oldQueue = (SessionTasksQueue) session.setAttributeIfAbsent(TASKS_QUEUE, queue);

			if (oldQueue != null) {
				queue = oldQueue;
			}
		}

		return queue;
	}

	/** Per-session FIFO of pending tasks. */
	private static class SessionTasksQueue {
		private final Queue<Runnable> taskQueue = new ConcurrentLinkedQueue<Runnable>();
		/**
		 * {@code true} when no worker owns this session (it is neither waiting nor
		 * being drained). Read and written only while holding {@code taskQueue}'s lock.
		 */
		private boolean processingCompleted = true;
	}

	/**
	 * Worker loop: repeatedly takes a session from the waiting queue and runs all of its
	 * tasks. It exits on {@link #EXIT_SIGNAL}, or after an idle keep-alive period while
	 * the pool is above its core size.
	 */
	class Worker implements Runnable {

		@Override
		public void run() {
			// Stays true only if a task's exception escapes the loop.
			boolean completedAbruptly = true;
			try {
				for (;;) {
					IoSession session = fetchSession();

					idleWorkers.decrementAndGet();

					if (session == null) {
						// Keep-alive expired: exit only if the pool is above its core size.
						synchronized (workers) {
							if (workers.size() > getCorePoolSize()) {
								workers.remove(this);
								break;
							}
						}
					} else if (session == EXIT_SIGNAL) {
						break;
					} else {
						runTasks(session);
					}

					// Reached only after a normal pass, so a dying worker is never
					// counted as idle.
					idleWorkers.incrementAndGet();
				}
				completedAbruptly = false;
			} finally {
				synchronized (workers) {
					workers.remove(this);
					workers.notifyAll();

					if (completedAbruptly) {
						// A task threw and killed this worker; start a replacement so
						// queued sessions are still served (no-op after shutdown).
						addWorker();
					}
				}
			}
		}

		/**
		 * Waits up to the keep-alive time for a session (or exit signal).
		 *
		 * @return the next waiting session, or {@code null} if the keep-alive expired
		 */
		private IoSession fetchSession() {
			IoSession session = null;
			long deadlineTime = System.currentTimeMillis() + getKeepAliveTime(TimeUnit.MILLISECONDS);

			for (;;) {
				long waitTime = deadlineTime - System.currentTimeMillis();

				if (waitTime <= 0) {
					break;
				}

				try {
					session = waitingSessions.poll(waitTime, TimeUnit.MILLISECONDS);
					break;
				} catch (InterruptedException e) {
					// Interrupts are deliberately ignored: keep waiting until the
					// keep-alive deadline.
					continue;
				}
			}

			return session;
		}

		/**
		 * Runs the tasks of {@code session} in order until its queue is empty, then marks
		 * the session as no longer owned by a worker.
		 */
		private void runTasks(IoSession session) {
			SessionTasksQueue sessionTasksQueue = getSessionTasksQueue(session);
			Queue<Runnable> tasksQueue = sessionTasksQueue.taskQueue;

			for (;;) {
				Runnable task;

				synchronized (tasksQueue) {
					task = tasksQueue.poll();

					if (task == null) {
						sessionTasksQueue.processingCompleted = true;
						return;
					}
				}

				boolean ran = false;
				try {
					runTask(task);
					ran = true;
				} finally {
					if (!ran) {
						// processingCompleted is still false, so execute() will never
						// re-offer this session; hand it back so its remaining tasks are
						// picked up by another worker instead of being stranded.
						waitingSessions.offer(session);
					}
				}
			}
		}

		private void runTask(Runnable task) {
			task.run();
		}

	}
}
